Environment setup¶
Step 1: Install Required Libraries¶
In [ ]:
# Uninstall conflicting packages
!pip uninstall -y keras-tf tensorflowjs tensorflow
# Install required libraries
!pip install pymongo numpy pandas matplotlib scikit-learn tensorflow==2.15.0 tensorflowjs
WARNING: Skipping keras-tf as it is not installed. WARNING: Skipping tensorflowjs as it is not installed. Found existing installation: tensorflow 2.18.0 Uninstalling tensorflow-2.18.0: Successfully uninstalled tensorflow-2.18.0 Collecting pymongo Downloading pymongo-4.11.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (22 kB) Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (2.0.2) Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (2.2.2) Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (3.10.0) Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.6.1) Collecting tensorflow==2.15.0 Downloading tensorflow-2.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB) Collecting tensorflowjs Downloading tensorflowjs-4.22.0-py3-none-any.whl.metadata (3.2 kB) Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (1.4.0) Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (1.6.3) Requirement already satisfied: flatbuffers>=23.5.26 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (25.2.10) Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (0.6.0) Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (0.2.0) Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (3.13.0) Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (18.1.1) Collecting ml-dtypes~=0.2.0 (from tensorflow==2.15.0) Downloading ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB) Collecting numpy Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 61.0/61.0 kB 4.0 MB/s eta 0:00:00 Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (3.4.0) Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (24.2) Collecting protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 (from tensorflow==2.15.0) Downloading protobuf-4.25.6-cp37-abi3-manylinux2014_x86_64.whl.metadata (541 bytes) Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (75.2.0) Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (1.17.0) Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (2.5.0) Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (4.13.0) Collecting wrapt<1.15,>=1.11.0 (from tensorflow==2.15.0) Downloading wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB) Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (0.37.1) Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.11/dist-packages (from tensorflow==2.15.0) (1.71.0) Collecting tensorboard<2.16,>=2.15 (from tensorflow==2.15.0) Downloading tensorboard-2.15.2-py3-none-any.whl.metadata (1.7 kB) Collecting tensorflow-estimator<2.16,>=2.15.0 (from tensorflow==2.15.0) Downloading tensorflow_estimator-2.15.0-py2.py3-none-any.whl.metadata (1.3 kB) Collecting keras<2.16,>=2.15.0 (from tensorflow==2.15.0) Downloading keras-2.15.0-py3-none-any.whl.metadata (2.4 kB) Collecting dnspython<3.0.0,>=1.16.0 (from pymongo) Downloading dnspython-2.7.0-py3-none-any.whl.metadata (5.8 kB) Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas) (2.8.2) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas) (2025.2) Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas) (2025.2) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (1.3.1) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (4.56.0) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (1.4.8) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (11.1.0) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (3.2.3) Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.14.1) Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.4.2) Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.6.0) Requirement already satisfied: flax>=0.7.2 in /usr/local/lib/python3.11/dist-packages (from tensorflowjs) (0.10.4) Requirement already satisfied: importlib_resources>=5.9.0 in /usr/local/lib/python3.11/dist-packages (from tensorflowjs) (6.5.2) Requirement already satisfied: jax>=0.4.13 in /usr/local/lib/python3.11/dist-packages (from tensorflowjs) (0.5.2) Requirement already satisfied: jaxlib>=0.4.13 in /usr/local/lib/python3.11/dist-packages (from tensorflowjs) (0.5.1) Requirement already satisfied: tf-keras>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from tensorflowjs) (2.18.0) Collecting tensorflow-decision-forests>=1.5.0 (from tensorflowjs) Downloading tensorflow_decision_forests-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.3 kB) Requirement already satisfied: tensorflow-hub>=0.16.1 in /usr/local/lib/python3.11/dist-packages (from tensorflowjs) (0.16.1) Collecting packaging (from tensorflow==2.15.0) Downloading packaging-23.2-py3-none-any.whl.metadata (3.2 kB) Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from astunparse>=1.6.0->tensorflow==2.15.0) (0.45.1) Requirement already satisfied: msgpack in /usr/local/lib/python3.11/dist-packages (from flax>=0.7.2->tensorflowjs) (1.1.0) Requirement already satisfied: optax in /usr/local/lib/python3.11/dist-packages (from flax>=0.7.2->tensorflowjs) (0.2.4) Requirement already satisfied: orbax-checkpoint in /usr/local/lib/python3.11/dist-packages (from flax>=0.7.2->tensorflowjs) (0.11.10) Requirement already satisfied: tensorstore in /usr/local/lib/python3.11/dist-packages (from flax>=0.7.2->tensorflowjs) (0.1.72) Requirement already satisfied: rich>=11.1 in /usr/local/lib/python3.11/dist-packages (from flax>=0.7.2->tensorflowjs) (13.9.4) Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.11/dist-packages (from flax>=0.7.2->tensorflowjs) (6.0.2) Requirement already satisfied: treescope>=0.1.7 in /usr/local/lib/python3.11/dist-packages (from flax>=0.7.2->tensorflowjs) (0.1.9) INFO: pip is looking at multiple versions of jax to determine which version is compatible with other requirements. This could take a while. Collecting jax>=0.4.13 (from tensorflowjs) Downloading jax-0.5.3-py3-none-any.whl.metadata (22 kB) Collecting jaxlib>=0.4.13 (from tensorflowjs) Downloading jaxlib-0.5.3-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB) Collecting jax>=0.4.13 (from tensorflowjs) Downloading jax-0.5.1-py3-none-any.whl.metadata (22 kB) Downloading jax-0.5.0-py3-none-any.whl.metadata (22 kB) Collecting jaxlib>=0.4.13 (from tensorflowjs) Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes) Collecting jax>=0.4.13 (from tensorflowjs) Downloading jax-0.4.38-py3-none-any.whl.metadata (22 kB) Collecting jaxlib>=0.4.13 (from tensorflowjs) Downloading jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.0 kB) Collecting jax>=0.4.13 (from tensorflowjs) Downloading jax-0.4.37-py3-none-any.whl.metadata (22 kB) Collecting jaxlib>=0.4.13 (from tensorflowjs) Downloading jaxlib-0.4.36-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.0 kB) Collecting jax>=0.4.13 (from tensorflowjs) Downloading jax-0.4.36-py3-none-any.whl.metadata (22 kB) Downloading jax-0.4.35-py3-none-any.whl.metadata (22 kB) Collecting jaxlib>=0.4.13 (from tensorflowjs) Downloading jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl.metadata (983 bytes) INFO: pip is still looking at multiple versions of jax to determine which version is compatible with other requirements. This could take a while. Collecting jax>=0.4.13 (from tensorflowjs) Downloading jax-0.4.34-py3-none-any.whl.metadata (22 kB) Collecting jaxlib>=0.4.13 (from tensorflowjs) Downloading jaxlib-0.4.34-cp311-cp311-manylinux2014_x86_64.whl.metadata (983 bytes) Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.16,>=2.15->tensorflow==2.15.0) (2.38.0) Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.16,>=2.15->tensorflow==2.15.0) (1.2.1) Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.16,>=2.15->tensorflow==2.15.0) (3.7) Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.16,>=2.15->tensorflow==2.15.0) (2.32.3) Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.16,>=2.15->tensorflow==2.15.0) (0.7.2) Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.16,>=2.15->tensorflow==2.15.0) (3.1.3) INFO: pip is looking at multiple versions of tensorflow-decision-forests to determine which version is compatible with other requirements. This could take a while. Collecting tensorflow-decision-forests>=1.5.0 (from tensorflowjs) Downloading tensorflow_decision_forests-1.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB) Downloading tensorflow_decision_forests-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.3 kB) Downloading tensorflow_decision_forests-1.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB) Downloading tensorflow_decision_forests-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB) Downloading tensorflow_decision_forests-1.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB) Downloading tensorflow_decision_forests-1.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB) Downloading tensorflow_decision_forests-1.8.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB) Collecting wurlitzer (from tensorflow-decision-forests>=1.5.0->tensorflowjs) Downloading wurlitzer-3.1.1-py3-none-any.whl.metadata (2.5 kB) INFO: pip is looking at multiple versions of tf-keras to determine which version is compatible with other requirements. This could take a while. Collecting tf-keras>=2.13.0 (from tensorflowjs) Downloading tf_keras-2.19.0-py3-none-any.whl.metadata (1.8 kB) Downloading tf_keras-2.17.0-py3-none-any.whl.metadata (1.6 kB) Downloading tf_keras-2.16.0-py3-none-any.whl.metadata (1.6 kB) Downloading tf_keras-2.15.1-py3-none-any.whl.metadata (1.7 kB) Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (5.5.2) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (0.4.2) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.11/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (4.9) Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (2.0.0) Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (3.4.1) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (3.10) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (2.3.0) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (2025.1.31) Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=11.1->flax>=0.7.2->tensorflowjs) (3.0.0) Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=11.1->flax>=0.7.2->tensorflowjs) (2.18.0) Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.11/dist-packages (from werkzeug>=1.0.1->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (3.0.2) Requirement already satisfied: chex>=0.1.87 in /usr/local/lib/python3.11/dist-packages (from optax->flax>=0.7.2->tensorflowjs) (0.1.89) Requirement already satisfied: etils[epy] in /usr/local/lib/python3.11/dist-packages (from optax->flax>=0.7.2->tensorflowjs) (1.12.2) Requirement already satisfied: nest_asyncio in /usr/local/lib/python3.11/dist-packages (from orbax-checkpoint->flax>=0.7.2->tensorflowjs) (1.6.0) Requirement already satisfied: humanize in /usr/local/lib/python3.11/dist-packages (from orbax-checkpoint->flax>=0.7.2->tensorflowjs) (4.12.2) Requirement already satisfied: simplejson>=3.16.0 in /usr/local/lib/python3.11/dist-packages (from orbax-checkpoint->flax>=0.7.2->tensorflowjs) (3.20.1) Collecting orbax-checkpoint (from flax>=0.7.2->tensorflowjs) Downloading orbax_checkpoint-0.11.10-py3-none-any.whl.metadata (2.0 kB) INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C. Downloading orbax_checkpoint-0.11.9-py3-none-any.whl.metadata (2.0 kB) Downloading orbax_checkpoint-0.11.8-py3-none-any.whl.metadata (2.0 kB) Downloading orbax_checkpoint-0.11.7-py3-none-any.whl.metadata (2.0 kB) Downloading orbax_checkpoint-0.11.6-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.11.5-py3-none-any.whl.metadata (1.9 kB) INFO: pip is looking at multiple versions of tensorstore to determine which version is compatible with other requirements. This could take a while. Collecting tensorstore (from flax>=0.7.2->tensorflowjs) Downloading tensorstore-0.1.73-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB) Downloading tensorstore-0.1.71-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB) Downloading tensorstore-0.1.69-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.68-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Collecting orbax-checkpoint (from flax>=0.7.2->tensorflowjs) Downloading orbax_checkpoint-0.11.4-py3-none-any.whl.metadata (1.9 kB) INFO: pip is still looking at multiple versions of tensorstore to determine which version is compatible with other requirements. This could take a while. Collecting tensorstore (from flax>=0.7.2->tensorflowjs) Downloading tensorstore-0.1.67-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.66-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.65-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C. Downloading tensorstore-0.1.64-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.63-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.62-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.61-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.60-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.59-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.58-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.57-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.56-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.55-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.54-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.53-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.52-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.51-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.50-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.49-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB) Downloading tensorstore-0.1.48-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB) Downloading tensorstore-0.1.47-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB) Downloading tensorstore-0.1.46-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB) Downloading tensorstore-0.1.45-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB) Collecting orbax-checkpoint (from flax>=0.7.2->tensorflowjs) Downloading orbax_checkpoint-0.11.3-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.11.2-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.11.1-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.11.0-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.10.3-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.10.2-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.10.1-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.10.0-py3-none-any.whl.metadata (1.9 kB) Downloading orbax_checkpoint-0.9.1-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.9.0-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.8.0-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.7.0-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.6.4-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.6.3-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.6.2-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.6.1-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.6.0-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.23-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.22-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.21-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.20-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.19-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.18-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.17-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.16-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.15-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.14-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.13-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.12-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.11-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.10-py3-none-any.whl.metadata (1.8 kB) Downloading orbax_checkpoint-0.5.9-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.8-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.7-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.6-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.5-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.4-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.3-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.2-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.1-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.5.0-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.4.8-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.4.7-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.4.6-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.4.5-py3-none-any.whl.metadata (1.7 kB) Downloading orbax_checkpoint-0.4.4-py3-none-any.whl.metadata (1.7 kB) Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from chex>=0.1.87->optax->flax>=0.7.2->tensorflowjs) (0.12.1) Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax>=0.7.2->tensorflowjs) (0.1.2) Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.11/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (0.6.1) Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.11/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow==2.15.0) (3.2.2) Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from etils[epath,epy]->orbax-checkpoint->flax>=0.7.2->tensorflowjs) (2025.3.0) Requirement already satisfied: zipp in /usr/local/lib/python3.11/dist-packages (from etils[epath,epy]->orbax-checkpoint->flax>=0.7.2->tensorflowjs) (3.21.0) Downloading tensorflow-2.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (475.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 475.3/475.3 MB 3.5 MB/s eta 0:00:00 Downloading pymongo-4.11.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.4 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.4/1.4 MB 47.4 MB/s eta 0:00:00 Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 40.3 MB/s eta 0:00:00 Downloading tensorflowjs-4.22.0-py3-none-any.whl (89 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 89.1/89.1 kB 6.2 MB/s eta 0:00:00 Downloading dnspython-2.7.0-py3-none-any.whl (313 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 313.6/313.6 kB 14.5 MB/s eta 0:00:00 Downloading jax-0.4.34-py3-none-any.whl (2.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 53.0 MB/s eta 0:00:00 Downloading jaxlib-0.4.34-cp311-cp311-manylinux2014_x86_64.whl (86.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.1/86.1 MB 8.7 MB/s eta 0:00:00 Downloading keras-2.15.0-py3-none-any.whl (1.7 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 61.1 MB/s eta 0:00:00 Downloading ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 41.6 MB/s eta 0:00:00 Downloading packaging-23.2-py3-none-any.whl (53 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 53.0/53.0 kB 3.1 MB/s eta 0:00:00 Downloading protobuf-4.25.6-cp37-abi3-manylinux2014_x86_64.whl (294 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 294.6/294.6 kB 20.2 MB/s eta 0:00:00 Downloading tensorboard-2.15.2-py3-none-any.whl (5.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.5/5.5 MB 81.4 MB/s eta 0:00:00 Downloading tensorflow_decision_forests-1.8.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 15.3/15.3 MB 77.5 MB/s eta 0:00:00 Downloading tensorflow_estimator-2.15.0-py2.py3-none-any.whl (441 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 442.0/442.0 kB 29.3 MB/s eta 0:00:00 Downloading tf_keras-2.15.1-py3-none-any.whl (1.7 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 61.8 MB/s eta 0:00:00 Downloading wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (78 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.4/78.4 kB 6.4 MB/s eta 0:00:00 Downloading tensorstore-0.1.45-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 81.7 MB/s eta 0:00:00 Downloading orbax_checkpoint-0.4.4-py3-none-any.whl (123 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.0/124.0 kB 10.3 MB/s eta 0:00:00 Downloading wurlitzer-3.1.1-py3-none-any.whl (8.6 kB) Installing collected packages: wurlitzer, wrapt, tensorflow-estimator, protobuf, packaging, numpy, keras, dnspython, tensorstore, pymongo, ml-dtypes, jaxlib, tensorboard, jax, tensorflow, orbax-checkpoint, tf-keras, tensorflow-decision-forests, tensorflowjs Attempting uninstall: wrapt Found existing installation: wrapt 1.17.2 Uninstalling wrapt-1.17.2: Successfully uninstalled wrapt-1.17.2 Attempting uninstall: protobuf Found existing installation: protobuf 5.29.4 Uninstalling protobuf-5.29.4: Successfully uninstalled protobuf-5.29.4 Attempting uninstall: packaging Found existing installation: packaging 24.2 Uninstalling packaging-24.2: Successfully uninstalled packaging-24.2 Attempting uninstall: numpy Found existing installation: numpy 2.0.2 Uninstalling numpy-2.0.2: Successfully uninstalled numpy-2.0.2 Attempting uninstall: keras Found existing installation: keras 3.8.0 Uninstalling keras-3.8.0: Successfully uninstalled keras-3.8.0 Attempting uninstall: tensorstore Found existing installation: tensorstore 0.1.72 Uninstalling tensorstore-0.1.72: Successfully uninstalled tensorstore-0.1.72 Attempting uninstall: ml-dtypes Found existing installation: ml-dtypes 0.4.1 Uninstalling ml-dtypes-0.4.1: Successfully uninstalled ml-dtypes-0.4.1 Attempting uninstall: jaxlib Found existing installation: jaxlib 0.5.1 Uninstalling jaxlib-0.5.1: Successfully uninstalled jaxlib-0.5.1 Attempting uninstall: tensorboard Found existing installation: tensorboard 2.18.0 Uninstalling tensorboard-2.18.0: Successfully uninstalled tensorboard-2.18.0 Attempting uninstall: jax Found existing installation: jax 0.5.2 Uninstalling jax-0.5.2: Successfully uninstalled jax-0.5.2 Attempting uninstall: orbax-checkpoint Found existing installation: orbax-checkpoint 0.11.10 Uninstalling orbax-checkpoint-0.11.10: Successfully uninstalled orbax-checkpoint-0.11.10 Attempting uninstall: tf-keras Found existing installation: tf_keras 2.18.0 Uninstalling tf_keras-2.18.0: Successfully uninstalled tf_keras-2.18.0 ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. google-cloud-bigquery 3.31.0 requires packaging>=24.2.0, but you have packaging 23.2 which is incompatible. dopamine-rl 4.1.2 requires tf-keras>=2.18.0, but you have tf-keras 2.15.1 which is incompatible. grpcio-status 1.71.0 requires protobuf<6.0dev,>=5.26.1, but you have protobuf 4.25.6 which is incompatible. tensorflow-text 2.18.1 requires tensorflow<2.19,>=2.18.0, but you have tensorflow 2.15.0 which is incompatible. Successfully installed dnspython-2.7.0 jax-0.4.34 jaxlib-0.4.34 keras-2.15.0 ml-dtypes-0.2.0 numpy-1.26.4 orbax-checkpoint-0.4.4 packaging-23.2 protobuf-4.25.6 pymongo-4.11.3 tensorboard-2.15.2 tensorflow-2.15.0 tensorflow-decision-forests-1.8.1 tensorflow-estimator-2.15.0 tensorflowjs-4.22.0 tensorstore-0.1.45 tf-keras-2.15.1 wrapt-1.14.1 wurlitzer-3.1.1
Step 2: Verify Installed Packages¶
In [ ]:
import pymongo
import numpy
import pandas as pd
import matplotlib
import sklearn
import tensorflow
import tensorflowjs
packages = [pymongo, numpy, pd, matplotlib, sklearn, tensorflow, tensorflowjs]
for package in packages:
print(f"{package.__name__}: {package.__version__}")
pymongo: 4.11.3 numpy: 1.26.4 pandas: 2.2.2 matplotlib: 3.10.0 sklearn: 1.6.1 tensorflow: 2.15.0 tensorflowjs: 4.22.0
Connecting to MongoDB and Retrieving Data¶
Step 1: Connect to MongoDB¶
In [ ]:
from pymongo import MongoClient
try:
uri = "mongodb+srv://biof3003digitalhealth01:qoB38jemj4U5E7ZL@cluster0.usbry.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
client = MongoClient(uri)
database = client["test"]
collection = database["records"]
# Check document count and load data into a DataFrame
count = collection.estimated_document_count()
print(f"Number of documents: {count}")
data = list(collection.find({}))
df = pd.DataFrame(data)
client.close()
except Exception as e:
raise Exception("The following error occurred: ", e)
Number of documents: 482
Step 2: Inspect the data¶
In [ ]:
print(df.head())
_id heartRate hrv \
0 67ac666618d0086a8921b8cf 64 {'sdnn': 588, 'confidence': 48}
1 67ac75ce1a3e22117370bc64 85 {'sdnn': 69, 'confidence': 95}
2 67ac75d21a3e22117370bc66 80 {'sdnn': 44, 'confidence': 97}
3 67ac75d31a3e22117370bc68 80 {'sdnn': 58, 'confidence': 96}
4 67ac75d81a3e22117370bc6a 86 {'sdnn': 88, 'confidence': 94}
confidence ppgData \
0 100.000000 [41.4, 40, 40, 37, 37, 37.2, 37.2, 36.2, 36.2,...
1 90.961931 [407.2, 407, 407, 409.4, 409.4, 410.4, 410.4, ...
2 94.740703 [413.8, 408, 408, 408.2, 408.2, 407.6, 407.6, ...
3 92.779961 [408.8, 408.8, 408.8, 408.8, 409.6, 409.6, 410...
4 88.820899 [412.6, 412.6, 414.4, 414.4, 414.4, 414.4, 416...
timestamp __v subjectId
0 2025-02-12 09:14:13.047 0 NaN
1 2025-02-12 10:19:53.801 0 NaN
2 2025-02-12 10:20:02.024 0 NaN
3 2025-02-12 10:20:03.069 0 NaN
4 2025-02-12 10:20:08.098 0 NaN
Visualizing PPG Signals¶
Step 1: Plot All Signals¶
In [ ]:
import matplotlib.pyplot as plt
def plot_ppg_signals(df):
total_samples = len(df)
cols = 6
rows = (total_samples + cols - 1) // cols # Ceiling division
plt.figure(figsize=(18, 1 * rows))
for idx in range(total_samples):
ax = plt.subplot(rows, cols, idx + 1)
plt.plot(df['ppgData'][idx], color="black", linewidth=1.5)
plt.title(f"Sample {idx}", fontsize=8)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
plt.xticks([])
plt.yticks([])
plt.tight_layout(pad=1.0, w_pad=0.5, h_pad=1.5, rect=[0, 0, 1, 0.98])
plt.show()
plot_ppg_signals(df)
Step 2: Label and Color-Code Signals¶
In [ ]:
excellent_samples = list(range(16, 21)) + list(range(24, 29)) + list(range(30, 33)) + list(range(35, 67))
acceptable_samples = list(range(12, 16)) + [21, 29, 33, 34]
bad_samples = list(range(0, 12)) + list(range(22, 24)) + list(range(67, 207))
def plot_ppg_signals_with_labels(df, excellent_samples, acceptable_samples, bad_samples):
color_map = {
'excellent': '#00C0C7', # Turquoise
'acceptable': '#FFC300', # Yellow
'bad': '#C41E3A', # Red
'unlabeled': '#000000' # Black
}
label_map = {}
for idx in df.index:
if idx in excellent_samples:
label_map[idx] = 'excellent'
elif idx in acceptable_samples:
label_map[idx] = 'acceptable'
elif idx in bad_samples:
label_map[idx] = 'bad'
else:
label_map[idx] = 'unlabeled'
total_samples = len(df)
cols = 6
rows = (total_samples + cols - 1) // cols
plt.figure(figsize=(18, 1 * rows))
for idx in range(total_samples):
ax = plt.subplot(rows, cols, idx + 1)
label = label_map[idx]
color = color_map[label]
plt.plot(df['ppgData'][idx], color=color, linewidth=1.5)
plt.title(f"{label.capitalize()} Sample {idx}", fontsize=8)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
plt.xticks([])
plt.yticks([])
legend_elements = [plt.Line2D([0], [0], color=color, label=label.capitalize()) for label, color in color_map.items() if any(v == label for v in label_map.values())]
plt.figlegend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 1.02), ncol=4, fontsize=10)
plt.tight_layout(pad=1.0, w_pad=0.5, h_pad=1.5, rect=[0, 0, 1, 0.98])
plt.show()
# Example usage
plot_ppg_signals_with_labels(df, excellent_samples, acceptable_samples, bad_samples)
Extracting Features from PPG Signals¶
Step 1: Define Feature Extraction Function¶
In [ ]:
import numpy as np
from scipy.stats import entropy
from scipy.signal import welch
def extract_ppg_features(signal):
epsilon = 1e-7 # Small value for numerical stability
if len(signal) == 0: # Handle empty signals
return np.zeros(8) # Return a zero-vector of the same feature length
mean = np.mean(signal)
median = np.median(signal)
std = np.std(signal)
variance = np.var(signal)
diff = signal - mean
skewness = np.mean(np.power(diff, 3)) / (np.power(std, 3) + 1e-7)
kurtosis = np.mean(np.power(diff, 4)) / (np.power(std, 4) + 1e-7)
signal_range = np.max(signal) - np.min(signal)
zero_crossings = np.sum(np.diff(np.signbit(signal).astype(int)) != 0)
rms = np.sqrt(np.mean(np.square(signal)))
peak_to_peak = signal_range
# Frequency-Domain Features
fft_coeffs = np.fft.fft(signal)
fft_magnitudes = np.abs(fft_coeffs)[:len(signal)//2] # First half of FFT spectrum
fft_magnitudes /= (np.sum(fft_magnitudes) + epsilon) # Normalize magnitudes
dominant_freq = np.argmax(fft_magnitudes) # Most dominant frequency index
# Entropy (Measures randomness in signal)
hist, _ = np.histogram(signal, bins=10, density=True)
signal_entropy = entropy(hist + epsilon)
features = np.array([
mean, median, std, variance, skewness, kurtosis,
signal_range, zero_crossings, rms, peak_to_peak,
dominant_freq, signal_entropy
])
return features
Step 2: Prepare Dataset¶
In [ ]:
import tensorflow as tf
def prepare_dataset_with_features(df, excellent_samples, acceptable_samples, bad_samples):
# Convert labels into categorical values
y_labels = np.zeros(len(df), dtype=int)
y_labels[excellent_samples] = 2 # Excellent
y_labels[acceptable_samples] = 1 # Acceptable
y_labels[bad_samples] = 0 # Bad
y = tf.keras.utils.to_categorical(y_labels, num_classes=3)
# Ensure all signals are valid before processing
valid_signals = [signal if isinstance(signal, np.ndarray) and signal.size > 0 else np.zeros(100) for signal in df['ppgData']]
# Extract features
X = np.array([extract_ppg_features(signal) for signal in valid_signals])
return X, y
In [ ]:
X, y = prepare_dataset_with_features(df, excellent_samples, acceptable_samples, bad_samples)
Training a Classification Model¶
Step 1: Create and Compile the Model¶
In [ ]:
import tensorflow as tf
def create_improved_model(feature_dim=12):
input_layer = tf.keras.layers.Input(shape=(feature_dim,), dtype=tf.float32, name='feature_input')
# Batch Normalization BEFORE Activation
x = tf.keras.layers.BatchNormalization()(input_layer)
# Fully connected layers with ReLU activation
x = tf.keras.layers.Dense(256)(x) # Increased neurons
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Dropout(0.3)(x) # Slightly increased dropout
x = tf.keras.layers.Dense(128)(x) # Added layer
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.3)(x)
x = tf.keras.layers.Dense(64)(x) # Increased neurons
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Dense(32)(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Dropout(0.1)(x)
# Output Layer with Softmax for Classification
outputs = tf.keras.layers.Dense(3, activation='softmax', name='classification')(x)
# Create and compile model
model = tf.keras.Model(inputs=input_layer, outputs=outputs)
return model
Step 2: Train the Model¶
In [ ]:
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import SGD
def train_and_save_model_with_features(df, excellent_samples, acceptable_samples, bad_samples):
"""Train model with pre-extracted features"""
# Prepare dataset
X, y = prepare_dataset_with_features(df, excellent_samples, acceptable_samples, bad_samples)
# Split dataset
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# Create and compile model
model = create_improved_model(feature_dim=12)
model.compile(optimizer=SGD(learning_rate=0.01, momentum=0.9),
loss='categorical_crossentropy',
metrics=['accuracy'])
# Create callbacks
callbacks = [
tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=10,
restore_best_weights=True
),
tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.5,
patience=5
),
tf.keras.callbacks.ModelCheckpoint(
'best_model.h5',
save_best_only=True,
monitor='val_accuracy'
)
]
# Train model
history = model.fit(
X_train,
y_train,
epochs=40,
batch_size=16,
validation_split=0.2,
callbacks=callbacks,
verbose=1
)
return model, history
In [ ]:
# Train the model
model, history = train_and_save_model_with_features(df, excellent_samples, acceptable_samples, bad_samples)
Epoch 1/40 20/20 [==============================] - 2s 23ms/step - loss: 0.7324 - accuracy: 0.8409 - val_loss: 0.4503 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 2/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4601 - accuracy: 0.8929 - val_loss: 0.4472 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 3/40 1/20 [>.............................] - ETA: 0s - loss: 0.4611 - accuracy: 0.8750
/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.
saving_api.save_model(
20/20 [==============================] - 0s 5ms/step - loss: 0.4130 - accuracy: 0.8929 - val_loss: 0.4434 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 4/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4219 - accuracy: 0.8929 - val_loss: 0.4338 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 5/40 20/20 [==============================] - 0s 5ms/step - loss: 0.4381 - accuracy: 0.8929 - val_loss: 0.4439 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 6/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4296 - accuracy: 0.8929 - val_loss: 0.4328 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 7/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4277 - accuracy: 0.8929 - val_loss: 0.4534 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 8/40 20/20 [==============================] - 0s 7ms/step - loss: 0.4018 - accuracy: 0.8929 - val_loss: 0.4532 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 9/40 20/20 [==============================] - 0s 7ms/step - loss: 0.3940 - accuracy: 0.8929 - val_loss: 0.4385 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 10/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4148 - accuracy: 0.8929 - val_loss: 0.4364 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 11/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4196 - accuracy: 0.8929 - val_loss: 0.4296 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 12/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4086 - accuracy: 0.8929 - val_loss: 0.4368 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 13/40 20/20 [==============================] - 0s 6ms/step - loss: 0.3927 - accuracy: 0.8929 - val_loss: 0.4395 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 14/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4084 - accuracy: 0.8929 - val_loss: 0.4391 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 15/40 20/20 [==============================] - 0s 7ms/step - loss: 0.4146 - accuracy: 0.8929 - val_loss: 0.4392 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 16/40 20/20 [==============================] - 0s 8ms/step - loss: 0.4090 - accuracy: 0.8929 - val_loss: 0.4508 - val_accuracy: 0.8831 - lr: 0.0100 Epoch 17/40 20/20 [==============================] - 0s 7ms/step - loss: 0.3964 - accuracy: 0.8929 - val_loss: 0.4558 - val_accuracy: 0.8831 - lr: 0.0050 Epoch 18/40 20/20 [==============================] - 0s 6ms/step - loss: 0.3893 - accuracy: 0.8929 - val_loss: 0.4451 - val_accuracy: 0.8831 - lr: 0.0050 Epoch 19/40 20/20 [==============================] - 0s 7ms/step - loss: 0.4029 - accuracy: 0.8929 - val_loss: 0.4492 - val_accuracy: 0.8831 - lr: 0.0050 Epoch 20/40 20/20 [==============================] - 0s 6ms/step - loss: 0.4016 - accuracy: 0.8929 - val_loss: 0.4532 - val_accuracy: 0.8831 - lr: 0.0050 Epoch 21/40 20/20 [==============================] - 0s 7ms/step - loss: 0.3941 - accuracy: 0.8929 - val_loss: 0.4439 - val_accuracy: 0.8831 - lr: 0.0050
Step 3: Visualize the Training Process¶
In [ ]:
def plot_training_history(history):
"""Plot training history"""
plt.figure(figsize=(12, 4))
# Plot accuracy
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
# Plot loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.savefig('training_history.png')
plt.show()
In [ ]:
# Plot the training history
plot_training_history(history)
Step 4: Save the Trained Model¶
In [ ]:
import tensorflowjs as tfjs
model.save('final_model.h5') # Save the entire model
tfjs.converters.save_keras_model(model, 'tfjs_model')